# ROUND 3 (3/18/24): adds a second bootstrap that resamples the taxpayer (TP) population
# import packages and define functions
import pandas as pd
import numpy as np
import statsmodels.api as sm
import math
import sys
import matplotlib.pyplot as plt

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Linear (OLS) disparity estimator.

    Regresses `outcome` on the probability column `pbvarb` using
    heteroskedasticity-robust (HC1) standard errors, and prints the number of
    observations used in the fit.

    Returns (coef, se, black_aud_rate, non_black_aud_rate):
      coef               -- slope on `pbvarb`
      se                 -- HC1 standard error of that slope
      black_aud_rate     -- intercept + slope (fitted value at pbvarb = 1)
      non_black_aud_rate -- intercept (fitted value at pbvarb = 0)
    """
    model = sm.OLS.from_formula(outcome + ' ~ ' + pbvarb, data = dataset).fit(cov_type = 'HC1')
    # Look parameters up by label, not position: positional indexing of a
    # label-indexed Series (params[0]) is deprecated in pandas 2.x.
    # 'Intercept' is the name the formula interface gives the constant term.
    intercept = model.params['Intercept']
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    black_aud_rate = intercept + coef
    non_black_aud_rate = intercept
    print("number of obs :" + str(model.nobs))
    return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Probabilistic (weighting) disparity estimator.

    Each row's outcome is weighted by its probability p = dataset[pbvarb]
    for the black rate and by 1 - p for the non-black rate:

        black_aud_rate     = sum(p * y) / sum(p)
        non_black_aud_rate = sum((1 - p) * y) / sum(1 - p)

    Returns (black_aud_rate - non_black_aud_rate, black_aud_rate,
    non_black_aud_rate).
    """
    probs = dataset[pbvarb]
    outcomes = dataset[outcome]
    complement = 1 - probs

    rate_black = (probs * outcomes).sum() / probs.sum()
    rate_non_black = (complement * outcomes).sum() / complement.sum()
    return rate_black - rate_non_black, rate_black, rate_non_black

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
    """Return (seReg, seChen) for the linear and probabilistic estimators.

    seReg is passed through unchanged. seChen is seReg scaled by
    getSEMultiplier(dataset, pbvarb).

    `outcome` is accepted for signature compatibility but is not used. The
    regression SE must be supplied by the caller through `seReg`. With the
    default seReg=0.0, both returned SEs are zero.
    """
    return seReg, seReg * getSEMultiplier(dataset, pbvarb)

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
    """Return sqrt(Var(p) / (mean(p) * (1 - mean(p)))) for p = dataset[pbvarb].

    Var(p) is the pandas sample variance (ddof=1). getSEs multiplies the
    regression SE by this factor to obtain the probabilistic-estimator SE.
    """
    probs = dataset[pbvarb]
    mean_prob = probs.mean()
    # variance of a Bernoulli variable with success probability mean_prob
    bernoulli_var = mean_prob * (1 - mean_prob)
    return np.sqrt(probs.var() / bernoulli_var)

### Operational audits and predicted race probabilities
# taxpayer_id_new is the key used to merge the bootstrap posterior files onto this data.
keep_cols=['taxpayer_id', 'taxpayer_id_new', 'isEIC', 'predicted_prob_black', 'aud_no_research_audits']
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final_new_eic.csv", usecols=keep_cols)

# Drop rows with no predicted probability and report how many rows remain.
# (Bare len(...) expressions only display in a notebook, not when run as a script.)
n_rows_read = len(dataBISG)
dataBISG = dataBISG.dropna(subset=['predicted_prob_black'])
print("BISG rows read: " + str(n_rows_read) + "; rows with predicted_prob_black: " + str(len(dataBISG)))

# Slim copy (merge key, outcome, EITC flag) used by the dual-bootstrap merge.
dataBISG_slim = dataBISG[['taxpayer_id_new', 'aud_no_research_audits', 'isEIC']]
# Dual bootstrap (slow). Each iteration i does the following:
#   1. Attach the BIFSG posteriors from iteration i of the first bootstrapping
#      procedure.
#   2. Resample the taxpayer population with replacement: the full population,
#      the EITC subset and the non-EITC subset.
#   3. Record the linear and probabilistic audit-disparity results.
# The cumulative results are written to CSV after every iteration. Set
# RERUN_BOOTSTRAP = True to regenerate that CSV. Otherwise the saved results
# file is read back below.
RERUN_BOOTSTRAP = False


def _dual_bootstrap_stats(boot, col):
    """Return the nine statistics recorded for one bootstrap sample.

    Order: linear disparity, probabilistic disparity, linear black rate,
    probabilistic black rate, linear non-black rate, probabilistic non-black
    rate, linear SE, probabilistic SE, sample size.
    """
    outcome = 'aud_no_research_audits'
    lin = regEstimate(boot, pbvarb = col, outcome = outcome)
    prob = chenEstimate(boot, pbvarb = col, outcome = outcome)
    lin_se, prob_se = getSEs(boot, pbvarb = col, outcome = outcome, seReg = float(lin[1]))
    return [lin[0], prob[0], lin[2], prob[1], lin[3], prob[2], lin_se, prob_se, len(boot)]


def _run_dual_bootstrap(n_iter=100):
    """Run n_iter dual-bootstrap iterations, checkpointing results to CSV each iteration."""
    from timeit import default_timer as timer
    stats = ['overall_lin', 'overall_prob', 'black_lin', 'black_prob', 'non_black_lin',
             'non_black_prob', 'lin_se', 'prob_se', 'n']
    # e.g. full_pop_overall_lin, ..., eitc_n, ..., non_eitc_n
    columns = [pop + '_' + stat for pop in ('full_pop', 'eitc', 'non_eitc') for stat in stats]
    rows = []
    for i in range(n_iter):
        start = timer()
        print(i+1)
        # read in this iteration's posteriors from the first bootstrapping procedure
        dt_tmp = pd.read_stata('/REDACTED/BIFSG/Bootstraps_V2/dual_boot_results/dual_loop_' + str(i+1) + '.dta.gz')
        col = 'p_black_' + str(i+1)
        dt_tmp['taxpayer_id_new'] = dt_tmp['taxpayer_id']
        dt_tmp_slim = dt_tmp[[col, 'taxpayer_id_new']]
        # validate='many_to_one' raises if a taxpayer appears more than once in the
        # posterior file; otherwise the left merge would silently duplicate rows.
        merged = dataBISG_slim.merge(dt_tmp_slim, on='taxpayer_id_new', how='left', validate='many_to_one')
        merged = merged[merged[col].notnull()]
        keep = [col, 'aud_no_research_audits']
        subsets = (
            merged[keep],
            merged.loc[merged['isEIC'] == 1, keep],
            merged.loc[merged['isEIC'] == 0, keep],
        )
        row = []
        for subset in subsets:
            # each subset is resampled with the same seed (i)
            boot = subset.sample(frac = 1, random_state = i, replace = True)
            row.extend(_dual_bootstrap_stats(boot, col))
        rows.append(row)
        # checkpoint all results so far
        pd.DataFrame(rows, columns = columns).to_csv('/REDACTED/disparity_100bootstraps_final_V3.csv', index=False)
        end = timer()
        print("Time: "+str(end-start))


if RERUN_BOOTSTRAP:
    _run_dual_bootstrap()
def _plot_bootstrap_hist(draws, xlabel, point_estimate, xlim, out_path):
    """Save a 10-bin histogram of `draws`, scaled by 100.

    A dashed black vertical line marks `point_estimate` (also scaled by 100).
    The x-axis is clipped to `xlim`, and the figure is closed after saving.
    """
    plt.hist(draws*100, bins=10)
    plt.xlabel(xlabel)
    plt.ylabel('Frequency')
    plt.axvline(point_estimate*100, color = "Black", linestyle='--')
    plt.xlim(*xlim)
    plt.savefig(out_path)
    plt.close()


# full-sample point estimates, drawn as the vertical reference line in each plot
point_estimate_lin = regEstimate(dataBISG, pbvarb = 'predicted_prob_black', outcome = 'aud_no_research_audits')[0]
point_estimate_prob = chenEstimate(dataBISG, pbvarb = 'predicted_prob_black', outcome = 'aud_no_research_audits')[0]

# load the saved bootstrap results so the bootstrap does not have to be re-run
disp_df = pd.read_csv('/REDACTED/disparity_100bootstraps_final_V3.csv')

# distributions of the full-population linear and probabilistic bootstrap estimates
_plot_bootstrap_hist(disp_df['full_pop_overall_lin'], 'Linear Disparity Estimate',
                     point_estimate_lin, (1.07, 1.6),
                     '/REDACTED/linear_disparity_estimates_dual_bootstrap_V3.png')
_plot_bootstrap_hist(disp_df['full_pop_overall_prob'], 'Probabilistic Disparity Estimate',
                     point_estimate_prob, (0.65, 1),
                     '/REDACTED/probabilistic_disparity_estimates_dual_bootstrap_V3.png')


# Appendix Table A.12

# Split the sample by EITC status (isEIC == 1 vs 0). Compute the linear and
# probabilistic full-sample point estimates for each subpopulation.
# regEstimate's "number of obs" lines print here, to the console.
_eic_flag = dataBISG['isEIC']
eitc_pop = dataBISG.loc[_eic_flag == 1]
non_eitc_pop = dataBISG.loc[_eic_flag == 0]

_estimator_args = {'pbvarb': 'predicted_prob_black', 'outcome': 'aud_no_research_audits'}
point_estimate_lin_eitc = regEstimate(eitc_pop, **_estimator_args)[0]
point_estimate_prob_eitc = chenEstimate(eitc_pop, **_estimator_args)[0]
point_estimate_lin_non_eitc = regEstimate(non_eitc_pop, **_estimator_args)[0]
point_estimate_prob_non_eitc = chenEstimate(non_eitc_pop, **_estimator_args)[0]

# Write the dual-bootstrap 95% confidence intervals to a text file. With 100
# draws:
#   - the lower bound averages the 3rd and 4th smallest draws;
#   - the upper bound averages the 97th and 98th smallest draws.
from contextlib import redirect_stdout


def _bootstrap_ci(draws, alpha=0.05):
    """Return (lower, upper) percentile bounds from a Series of bootstrap draws.

    Each bound is the mean of the two adjacent sorted draws that straddle the
    alpha/2 (resp. 1 - alpha/2) quantile, on a 0..n-1 position scale. For 100
    draws and alpha=0.05 these are sorted positions 2/3 and 96/97. Deriving the
    positions from len(draws) keeps the bounds meaningful if the results file
    holds a different number of draws, e.g. after a partial bootstrap run.

    Raises ValueError if fewer than two draws are given.
    """
    ordered = draws.sort_values().tolist()
    n_draws = len(ordered)
    if n_draws < 2:
        raise ValueError("need at least two bootstrap draws, got " + str(n_draws))
    lo = math.floor(alpha / 2 * (n_draws - 1))
    hi = math.floor((1 - alpha / 2) * (n_draws - 1))
    return (ordered[lo] + ordered[lo + 1]) / 2, (ordered[hi] + ordered[hi + 1]) / 2


_ci_sections = [
    # (section header, disp_df column prefix, linear point est., probabilistic point est., N)
    ("Full Population", "full_pop", point_estimate_lin, point_estimate_prob, len(dataBISG)),
    ("\nEITC", "eitc", point_estimate_lin_eitc, point_estimate_prob_eitc, len(eitc_pop)),
    ("\nNon-EITC", "non_eitc", point_estimate_lin_non_eitc, point_estimate_prob_non_eitc, len(non_eitc_pop)),
]

# The with-statement closes the file and restores the previous stdout, even if
# an error occurs partway through.
with open("/REDACTED/dual_bootstrap_confidence_intervals.txt", "w") as ci_file, redirect_stdout(ci_file):
    for header, prefix, point_lin, point_prob, n_obs in _ci_sections:
        print(header)
        print("Linear (point estimate, confidence interval lower bound, confidence interval upper bound)")
        print(point_lin)
        for bound in _bootstrap_ci(disp_df[prefix + "_overall_lin"]):
            print(bound)
        print("Probabilistic (point estimate, confidence interval lower bound, confidence interval upper bound)")
        print(point_prob)
        for bound in _bootstrap_ci(disp_df[prefix + "_overall_prob"]):
            print(bound)
        print("N")
        print(n_obs)
